from experiment_utils.largestconnectedcomponent import lcc_dataset
from utils.load_datasets import load_data,data_information
from experiment_utils.sdrf_cuda import sdrf_BFc,sdrf_JTc,sdrf_JLc,sdrf_AFc
from utils.seeds import val_seeds
from utils.splits import set_train_val_test_split,set_train_val_test_split_frac
from experiment_utils.experimentclass import Experiment

from torch_geometric.data import Data
import torch
import torch.nn.functional as F
import torch_geometric
# Use the CUDA device when one is visible to torch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("using device: ", device)
import numpy as np

from tqdm import tqdm 
import os
import json

import wandb
import sys

"""
Parameters for the experiment
"""

os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_CONSOLE"] = "off"
os.environ["NUMBA_CUDA_LOW_OCCUPANCY_WARNINGS"] = "false"

datasetname = "Texas"#sys.argv[1]
results_dir = "results"
rewiring_run = True#eval(sys.argv[2])
make_undirected = True
Curvature_type = "BFc_w4cycle"#sys.argv[3]


dataset,data,G = load_data(datasetname)
dataset_lcc = lcc_dataset(dataset,to_undirected = make_undirected)
data_lcc = dataset_lcc[0]

data_information(dataset_lcc,data_lcc)

path_hyper = "experiment_utils/hyperparameters/"
path_wandblog = ""

with open(os.path.join(path_hyper,'hyperparameters_Neurips_2.json'), 'r') as file:
     sweep_configuration = json.load(file)
     sweep_configuration =sweep_configuration.get(datasetname, {})

sweep_configuration["name"] = datasetname + '_' + Curvature_type

def objective(config, rewire=False, int_node=None):
    """Train and evaluate one configuration on every seed in ``val_seeds``.

    Args:
        config: Hyperparameter mapping for this trial (``wandb.config`` when
            called from ``main``). It is passed to ``Experiment`` and, when
            rewiring, to ``create_rewired_edge_index``.
        rewire: If True, rewire ``data_lcc`` once with the module-level
            ``Curvature_type`` and reuse the rewired edges for every seed.
        int_node: Intermediate-node setting forwarded to the BFc rewiring
            routine. When None it is looked up as ``config["int_node"]``.

    Returns:
        Tuple ``(mean validation accuracy, mean test accuracy)`` over all seeds.

    Raises:
        ValueError: If rewiring with a BFc curvature type and no
            intermediate-node setting is available.
    """
    accuracies = []
    test_acc = []
    # Early stopping: stop once validation accuracy has failed to improve for
    # more than this many consecutive epochs.
    patience = 100

    if rewire:
        # No module-level ``int_node`` exists, so it must come from the caller
        # or from the sweep configuration.
        if int_node is None:
            int_node = config.get("int_node")
        if int_node is None and Curvature_type.startswith("BFc"):
            raise ValueError(
                f"Curvature type {Curvature_type!r} needs an intermediate-node "
                "setting: pass int_node or add 'int_node' to the sweep config"
            )
        print("===Starting Rewiring===")
        _, edge_index_rewired = create_rewired_edge_index(
            data_lcc,
            config,
            intermediate_node=int_node,
            remove_edges=True,
            curvaturetype=Curvature_type,
        )
        print(" ")

    print(" == Starting Runs == ")

    for seed in val_seeds:
        # Planetoid datasets use their own split helper; the others get
        # fractional 0.2 / 0.2 splits.
        if datasetname in ("Cora", "Citeseer", "Pubmed"):
            data_undirected_split = set_train_val_test_split(seed, data_lcc)
        else:
            data_undirected_split = set_train_val_test_split_frac(seed, data_lcc, 0.2, 0.2)

        if rewire:
            data_undirected_split.edge_index = edge_index_rewired

        data_undirected_split.to(device)

        Exp = Experiment(device, datasetname, dataset_lcc, data_undirected_split, config)

        best_val = None
        counter = 0
        # NOTE(review): range(1, Exp.epoch) runs Exp.epoch - 1 epochs; confirm
        # this is intended rather than range(1, Exp.epoch + 1).
        for _ in range(1, Exp.epoch):
            Exp.train()
            val = Exp.validate()
            if best_val is None or val > best_val:
                best_val = val
                counter = 0
            else:
                counter += 1
            if counter > patience:
                break
        # Metrics are taken from the model's final state, not from the
        # best-validation epoch.
        accuracies.append(Exp.validate())
        test_acc.append(Exp.test())
    print("")
    return np.mean(np.array(accuracies)), np.mean(np.array(test_acc))

def _edges_to_undirected_index(graph):
    """Convert ``graph.edges`` into an undirected ``edge_index`` tensor of shape (2, E)."""
    # reshape(-1, 2) turns an edgeless graph into a (2, 0) tensor instead of
    # the 1-D empty tensor that torch.tensor([]).t() would produce.
    edges = torch.tensor(list(graph.edges), dtype=torch.long).reshape(-1, 2).t()
    return torch_geometric.utils.to_undirected(edges)


def create_rewired_edge_index(data, hyperparameters, intermediate_node, remove_edges, curvaturetype: str):
    """Rewire ``data`` with the SDRF routine selected by ``curvaturetype``.

    Args:
        data: Graph handed to the ``sdrf_*`` routine; its ``is_undirected()``
            result is forwarded as ``is_undirected``.
        hyperparameters: Mapping providing ``"loops"``, ``"C+"`` and ``"tau"``.
        intermediate_node: Forwarded as ``int_node`` to ``sdrf_BFc``; ignored
            for the other curvature types.
        remove_edges: Forwarded unchanged to the ``sdrf_*`` routine.
        curvaturetype: One of ``"BFc_w4cycle"``, ``"BFc_no4cycle"``, ``"JTc"``,
            ``"JLc"``, ``"AFc_3"`` or ``"AFc_4"``.

    Returns:
        ``(G_rewired, edge_index_rewired)``: the graph returned by the
        ``sdrf_*`` routine and its edges as an undirected ``edge_index``.

    Raises:
        ValueError: If ``curvaturetype`` is not one of the supported values.
    """
    negate_bound = False
    if curvaturetype in ("BFc_w4cycle", "BFc_no4cycle"):
        rewire_fn = sdrf_BFc
        # fcc=True for the "w4cycle" variant, False for "no4cycle".
        extra_kwargs = {"int_node": intermediate_node, "fcc": curvaturetype == "BFc_w4cycle"}
    elif curvaturetype == "JTc":
        rewire_fn, extra_kwargs = sdrf_JTc, {}
    elif curvaturetype == "JLc":
        rewire_fn, extra_kwargs = sdrf_JLc, {}
    elif curvaturetype in ("AFc_3", "AFc_4"):
        rewire_fn = sdrf_AFc
        # k values kept exactly as the per-variant calls used them (3. and 4).
        extra_kwargs = {"k": 3. if curvaturetype == "AFc_3" else 4}
        # The AFc routine receives the negated C+ as its removal bound.
        negate_bound = True
    else:
        raise ValueError(
            f"Unsupported curvature type {curvaturetype!r}; expected one of "
            "'BFc_w4cycle', 'BFc_no4cycle', 'JTc', 'JLc', 'AFc_3', 'AFc_4'"
        )

    removal_bound = hyperparameters["C+"]
    if negate_bound:
        removal_bound = -removal_bound

    G_rewired, _ = rewire_fn(
        data,
        loops=hyperparameters["loops"],
        remove_edges=remove_edges,
        removal_bound=removal_bound,
        tau=hyperparameters["tau"],
        is_undirected=data.is_undirected(),
        progress_bar=False,
        **extra_kwargs,
    )
    return G_rewired, _edges_to_undirected_index(G_rewired)
 
def main():
    """Run one sweep trial: evaluate the sampled config and log the mean accuracies."""
    wandb.init(dir=path_wandblog)
    mean_val_acc, mean_test_acc = objective(wandb.config, rewiring_run)
    metrics = {"mean accuracy": mean_val_acc, "mean test accuracy": mean_test_acc}
    wandb.log(metrics)


# Launch the sweep only when executed as a script, so that importing this
# module (e.g. to reuse create_rewired_edge_index) does not start up to 1000 runs.
if __name__ == "__main__":
    sweep_id = wandb.sweep(sweep=sweep_configuration, project="Curvature_Neurips_NodeClass_Final")
    wandb.agent(sweep_id, function=main, count=1000)

    